/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

/* Number of limbs in a field element. */
#define FE_SIZE 10

/*
 * A field element stored as FE_SIZE unsigned 32-bit limbs. The reduction
 * code in this file masks even-indexed limbs to 26 bits and odd-indexed limbs
 * to 25 bits (255 bits in total).
 */
typedef uint32_t FieldElement[FE_SIZE];
/* Same limb layout with 64-bit limbs, used to hold unreduced products. */
typedef uint64_t FieldElementWide[FE_SIZE];

/*
 * fe_copy(dst, src): copy FE_SIZE limbs from src to dst.
 * The loop counter has a macro-specific name so that arguments such as
 * `v[i]` keep referring to the caller's `i` instead of being captured by
 * the loop variable.
 */
#define fe_copy(dst, src)                                            \
    do                                                               \
    {                                                                \
        for (int fe_copy_i_ = 0; fe_copy_i_ < FE_SIZE; ++fe_copy_i_) \
        {                                                            \
            (dst)[fe_copy_i_] = (src)[fe_copy_i_];                   \
        }                                                            \
    } while (0)

DECLSPEC inline void fe_zero(FieldElement x)
{
    /* x = 0: clear all ten limbs. */
    for (int limb = 9; limb >= 0; --limb) {
        x[limb] = 0;
    }
}

DECLSPEC inline void fe_one(FieldElement x)
{
    /* x = 1: the least significant limb is 1, the remaining nine are 0. */
    for (int limb = 0; limb < 10; ++limb) {
        x[limb] = (limb == 0) ? 1u : 0u;
    }
}

/*
 * fe_square_and_reduce(z, x): z = x^2, with the carry chain interleaved with
 * the coefficient computation.
 *
 * Coefficient k collects the limb products x[i]*x[j] with i + j == k or
 * i + j == k + 10. Products whose index sum is 10 or more are multiplied by
 * 19 (the *_19 operands); this folds the part above bit 255 back into the
 * low limbs. The factors of 2 come from the *_2 operands and the trailing
 * `*2` groups. They account for two things:
 *   - each cross term x[i]*x[j] (i != j) appears twice;
 *   - a product of two odd-indexed limbs picks up an extra factor of 2 from
 *     the alternating 26-bit / 25-bit limb widths.
 *
 * Carries alternate between 26-bit masks (even limbs) and 25-bit masks (odd
 * limbs). The carry out of z[9] is multiplied by 19 and added to z[0]. The
 * last carry out of z[0] is added to z[1] without re-masking, so z[1] can
 * end up slightly wider than 25 bits.
 *
 * z must NOT alias x. z[k] is stored before later coefficients read x[k]
 * again (e.g. z[1] is written before coefficient 2 reads x[1]).
 * fe_square passes a temporary for this reason.
 *
 * NOTE(review): the 2*x[i] and 19*x[i] precomputations are 32-bit, so the
 * input limbs are presumably expected to be reduced (roughly 26/25 bits).
 * Confirm this for every caller.
 */
DECLSPEC inline void fe_square_and_reduce(FieldElement z, const FieldElement x)
{
    uint64_t LOW_25_BITS = (1ul << 25) - 1;
    uint64_t LOW_26_BITS = (1ul << 26) - 1;

    // c0 / c1 alternate as "current coefficient + carry from the previous
    // limb"; t holds the final fix-up of z[0].
    uint64_t c0, c1, t;

    // Local 32x32 -> 64-bit multiply helper; #undef'd at the end.
#   define m(a, b) ((uint64_t)(a) * (uint64_t)(b))

    // Doubled limbs (cross terms) and limbs pre-multiplied by 19 (wrap-around).
    uint32_t x0_2   =  2 * x[0];
    uint32_t x1_2   =  2 * x[1];
    uint32_t x2_2   =  2 * x[2];
    uint32_t x3_2   =  2 * x[3];
    uint32_t x4_2   =  2 * x[4];
    uint32_t x5_2   =  2 * x[5];
    uint32_t x6_2   =  2 * x[6];
    uint32_t x7_2   =  2 * x[7];
    uint32_t x5_19  = 19 * x[5];
    uint32_t x6_19  = 19 * x[6];
    uint32_t x7_19  = 19 * x[7];
    uint32_t x8_19  = 19 * x[8];
    uint32_t x9_19  = 19 * x[9];

    // Limb 0 (26 bits).
    c0 = m(x[0],x[0]) + m(x2_2,x8_19) + m(x4_2,x6_19) + (m(x1_2,x9_19) + m(x3_2,x7_19) + m(x[5],x5_19))*2;
    c1 = c0 >> 26;
    z[0] = c0 & LOW_26_BITS;

    // Limb 1 (25 bits).
    c1 += m(x0_2,x[1]) + m(x3_2,x8_19) + m(x5_2,x6_19) + (m(x[2],x9_19) + m(x[4],x7_19))*2;
    c0 = c1 >> 25;
    z[1] = c1 & LOW_25_BITS;

    // Limb 2 (26 bits).
    c0 += m(x0_2,x[2]) + m(x1_2,x[1]) + m(x4_2,x8_19) + m(x[6],x6_19) + (m(x3_2,x9_19) + m(x5_2,x7_19))*2;
    c1 = c0 >> 26;
    z[2] = c0 & LOW_26_BITS;

    // Limb 3 (25 bits).
    c1 += m(x0_2,x[3]) + m(x1_2,x[2]) + m(x5_2,x8_19) + (m(x[4],x9_19) + m(x[6],x7_19))*2;
    c0 = c1 >> 25;
    z[3] = c1 & LOW_25_BITS;

    // Limb 4 (26 bits).
    c0 += m(x0_2,x[4]) + m(x1_2,x3_2) + m(x[2],x[2]) + m(x6_2,x8_19) + (m(x5_2,x9_19) + m(x[7],x7_19))*2;
    c1 = c0 >> 26;
    z[4] = c0 & LOW_26_BITS;

    // Limb 5 (25 bits).
    c1 += m(x0_2,x[5]) + m(x1_2,x[4]) + m(x2_2,x[3]) + m(x7_2,x8_19) + m(x[6],x9_19)*2;
    c0 = c1 >> 25;
    z[5] = c1 & LOW_25_BITS;

    // Limb 6 (26 bits).
    c0 += m(x0_2,x[6]) + m(x1_2,x5_2) + m(x2_2,x[4]) + m(x3_2,x[3]) + m(x[8],x8_19) + m(x7_2,x9_19)*2;
    c1 = c0 >> 26;
    z[6] = c0 & LOW_26_BITS;

    // Limb 7 (25 bits).
    c1 += m(x0_2,x[7]) + m(x1_2,x[6]) + m(x2_2,x[5]) + m(x3_2,x[4]) + m(x[8],x9_19)*2;
    c0 = c1 >> 25;
    z[7] = c1 & LOW_25_BITS;

    // Limb 8 (26 bits).
    c0 += m(x0_2,x[8]) + m(x1_2,x7_2) + m(x2_2,x[6]) + m(x3_2,x5_2) + m(x[4],x[4]) + m(x[9],x9_19)*2;
    c1 = c0 >> 26;
    z[8] = c0 & LOW_26_BITS;

    // Limb 9 (25 bits). Its carry (weight 2^255) is folded into limb 0 as *19.
    c1 += m(x0_2,x[9]) + m(x1_2,x[8]) + m(x2_2,x[7]) + m(x3_2,x[6]) + m(x4_2,x[5]) ;
    t = z[0] + 19 * (c1 >> 25);
    z[9] = c1 & LOW_25_BITS;

    // Re-carry limb 0 once; z[1] is deliberately not re-masked.
    z[0] = t & LOW_26_BITS;
    z[1] += t >> 26;

#   undef m
}

/*
 * fe_reduce_carry(z, i): move the bits of limb z[i] that exceed its width
 * into limb z[i + 1]. Even-indexed limbs keep 26 bits, odd-indexed limbs
 * keep 25 bits.
 *
 * The expansion refers to LOW_25_BITS / LOW_26_BITS, which must be declared
 * in the calling scope (fe_reduce declares them). Every use in this file
 * passes a literal index, so the parity test is a compile-time constant.
 * z is evaluated several times, and every use is parenthesised.
 */
#define fe_reduce_carry(z, i)               \
    do                                      \
    {                                       \
        if ((i) % 2 == 0)                   \
        {                                   \
            (z)[(i) + 1] += (z)[(i)] >> 26; \
            (z)[(i)] &= LOW_26_BITS;        \
        }                                   \
        else                                \
        {                                   \
            (z)[(i) + 1] += (z)[(i)] >> 25; \
            (z)[(i)] &= LOW_25_BITS;        \
        }                                   \
    } while (0)

/*
 * fe_reduce(X, Z): carry-propagate the wide limbs of Z and store the result
 * in X.
 *
 * - Z is modified in place. It is expected to be indexable with 64-bit limbs
 *   (FieldElementWide at every call site in this file).
 * - Limbs 0..8 are carried upward with the 26/25-bit masks.
 * - The carry out of limb 9 (weight 2^255) is multiplied by 19 and added to
 *   limb 0.
 * - Limb 0 is then carried once more. As a result, Z[1] (and so X[1]) may
 *   slightly exceed 25 bits.
 *
 * The copy loop's counter has a macro-specific name so that X or Z
 * expressions that mention a caller variable `i` are not captured.
 */
#define fe_reduce(X, Z)                                               \
    do                                                                \
    {                                                                 \
        uint64_t LOW_25_BITS = (1ul << 25) - 1;                       \
        uint64_t LOW_26_BITS = (1ul << 26) - 1;                       \
                                                                      \
        fe_reduce_carry((Z), 0);                                      \
        fe_reduce_carry((Z), 1);                                      \
        fe_reduce_carry((Z), 2);                                      \
        fe_reduce_carry((Z), 3);                                      \
        fe_reduce_carry((Z), 4);                                      \
        fe_reduce_carry((Z), 5);                                      \
        fe_reduce_carry((Z), 6);                                      \
        fe_reduce_carry((Z), 7);                                      \
        fe_reduce_carry((Z), 8);                                      \
        /* fold the carry out of limb 9 back into limb 0 (times 19) */ \
        (Z)[0] += 19 * ((Z)[9] >> 25);                                \
        (Z)[9] &= LOW_25_BITS;                                        \
        fe_reduce_carry((Z), 0);                                      \
                                                                      \
        for (int fe_reduce_i_ = 0; fe_reduce_i_ < 10; ++fe_reduce_i_) \
        {                                                             \
            (X)[fe_reduce_i_] = (Z)[fe_reduce_i_];                    \
        }                                                             \
    } while (0)

/*
 * fe_add(z, x, y): limb-wise z = x + y, with no carry propagation (the
 * result's limbs may exceed their nominal 26/25-bit widths). z may alias x
 * or y. The loop counter has a macro-specific name so that arguments such as
 * `v[i]` keep referring to the caller's `i`.
 */
#define fe_add(z, x, y)                                       \
    do                                                        \
    {                                                         \
        for (int fe_add_i_ = 0; fe_add_i_ < 10; ++fe_add_i_)  \
        {                                                     \
            (z)[fe_add_i_] = (x)[fe_add_i_] + (y)[fe_add_i_]; \
        }                                                     \
    } while (0)

/*
 * fe_sub(z, x, y): z = x - y, computed as x + 16*p - y per limb and then
 * carried/reduced with fe_reduce.
 *
 * The bias limbs are 0x3ffffed, then alternating 0x1ffffff / 0x3ffffff, all
 * shifted left by 4. These are the limbs of p = 2^255 - 19 multiplied by 16.
 * Adding them first keeps each 32-bit difference from wrapping as long as
 * y[k] <= x[k] + bias[k].
 */
#define fe_sub(z, x, y)                                                      \
    do                                                                       \
    {                                                                        \
        const uint32_t fe_sub_bias_[10] = {                                  \
            0x3ffffed << 4, 0x1ffffff << 4, 0x3ffffff << 4, 0x1ffffff << 4,  \
            0x3ffffff << 4, 0x1ffffff << 4, 0x3ffffff << 4, 0x1ffffff << 4,  \
            0x3ffffff << 4, 0x1ffffff << 4                                   \
        };                                                                   \
        FieldElementWide fe_sub_wide_;                                       \
                                                                             \
        for (int fe_sub_k_ = 0; fe_sub_k_ < 10; ++fe_sub_k_)                 \
        {                                                                    \
            fe_sub_wide_[fe_sub_k_] =                                        \
                ((x)[fe_sub_k_] + fe_sub_bias_[fe_sub_k_]) - (y)[fe_sub_k_]; \
        }                                                                    \
                                                                             \
        fe_reduce((z), fe_sub_wide_);                                        \
    } while (0)

DECLSPEC inline void fe_neg(FieldElement z, const FieldElement y)
{
    /*
     * z = -y, computed as 16*p - y per limb and then reduced with fe_reduce.
     * The bias limbs are those of p = 2^255 - 19 (0x3ffffed, then
     * alternating 0x1ffffff / 0x3ffffff) shifted left by 4. The 32-bit
     * subtraction therefore does not wrap as long as each y[k] is at most
     * the corresponding bias limb.
     */
    const uint32_t sixteen_p[10] = {
        0x3ffffed << 4, 0x1ffffff << 4,
        0x3ffffff << 4, 0x1ffffff << 4,
        0x3ffffff << 4, 0x1ffffff << 4,
        0x3ffffff << 4, 0x1ffffff << 4,
        0x3ffffff << 4, 0x1ffffff << 4,
    };
    FieldElementWide wide;

    for (int k = 0; k < 10; ++k) {
        wide[k] = sixteen_p[k] - y[k];
    }
    fe_reduce(z, wide);
}

/*
 * M64(a, b): 32x32 -> 64-bit product of two limbs.
 * NOTE(review): unlike the local `m` helper in the squaring routines, M64 is
 * never #undef'd, so it stays defined for the rest of the translation unit.
 */
#define M64(a, b) ((uint64_t)(a) * (uint64_t)(b))
/*
 * fe_mul_(z, x, y): schoolbook product z = x * y, followed by fe_reduce.
 *
 * - _t[k] collects every x[i]*y[j] with i + j == k or i + j == k + 10.
 * - Terms with i + j >= 10 wrap past bit 255 and use the y*_19 operands
 *   (y[j] pre-multiplied by 19).
 * - When i and j are both odd, x*_2 (x[i] doubled) is used. This is the
 *   extra factor of 2 from the alternating 26/25-bit limb widths.
 *
 * The whole product is computed into the local _t before z is written, so z
 * may alias x or y (fe_invert calls fe_mul(t1, x, t1)).
 *
 * Macro hygiene: inside this macro x and y are indexed without parentheses,
 * so call it through fe_mul, which parenthesises its arguments. Do not pass
 * an argument named _t or one of the y*_19 / x*_2 local names.
 */
#define fe_mul_(z, x, y)                                                                                                                                                                                    \
    do                                                                                                                                                                                                      \
    {                                                                                                                                                                                                       \
        FieldElementWide _t;                                                                                                                                                                                \
                                                                                                                                                                                                            \
        uint32_t y1_19 = 19 * y[1];                                                                                                                                                                         \
        uint32_t y2_19 = 19 * y[2];                                                                                                                                                                         \
        uint32_t y3_19 = 19 * y[3];                                                                                                                                                                         \
        uint32_t y4_19 = 19 * y[4];                                                                                                                                                                         \
        uint32_t y5_19 = 19 * y[5];                                                                                                                                                                         \
        uint32_t y6_19 = 19 * y[6];                                                                                                                                                                         \
        uint32_t y7_19 = 19 * y[7];                                                                                                                                                                         \
        uint32_t y8_19 = 19 * y[8];                                                                                                                                                                         \
        uint32_t y9_19 = 19 * y[9];                                                                                                                                                                         \
                                                                                                                                                                                                            \
        uint32_t x1_2 = 2 * x[1];                                                                                                                                                                           \
        uint32_t x3_2 = 2 * x[3];                                                                                                                                                                           \
        uint32_t x5_2 = 2 * x[5];                                                                                                                                                                           \
        uint32_t x7_2 = 2 * x[7];                                                                                                                                                                           \
        uint32_t x9_2 = 2 * x[9];                                                                                                                                                                           \
                                                                                                                                                                                                            \
        _t[0] = M64(x[0], y[0]) + M64(x1_2, y9_19) + M64(x[2], y8_19) + M64(x3_2, y7_19) + M64(x[4], y6_19) + M64(x5_2, y5_19) + M64(x[6], y4_19) + M64(x7_2, y3_19) + M64(x[8], y2_19) + M64(x9_2, y1_19); \
        _t[1] = M64(x[0], y[1]) + M64(x[1], y[0]) + M64(x[2], y9_19) + M64(x[3], y8_19) + M64(x[4], y7_19) + M64(x[5], y6_19) + M64(x[6], y5_19) + M64(x[7], y4_19) + M64(x[8], y3_19) + M64(x[9], y2_19);  \
        _t[2] = M64(x[0], y[2]) + M64(x1_2, y[1]) + M64(x[2], y[0]) + M64(x3_2, y9_19) + M64(x[4], y8_19) + M64(x5_2, y7_19) + M64(x[6], y6_19) + M64(x7_2, y5_19) + M64(x[8], y4_19) + M64(x9_2, y3_19);   \
        _t[3] = M64(x[0], y[3]) + M64(x[1], y[2]) + M64(x[2], y[1]) + M64(x[3], y[0]) + M64(x[4], y9_19) + M64(x[5], y8_19) + M64(x[6], y7_19) + M64(x[7], y6_19) + M64(x[8], y5_19) + M64(x[9], y4_19);    \
        _t[4] = M64(x[0], y[4]) + M64(x1_2, y[3]) + M64(x[2], y[2]) + M64(x3_2, y[1]) + M64(x[4], y[0]) + M64(x5_2, y9_19) + M64(x[6], y8_19) + M64(x7_2, y7_19) + M64(x[8], y6_19) + M64(x9_2, y5_19);     \
        _t[5] = M64(x[0], y[5]) + M64(x[1], y[4]) + M64(x[2], y[3]) + M64(x[3], y[2]) + M64(x[4], y[1]) + M64(x[5], y[0]) + M64(x[6], y9_19) + M64(x[7], y8_19) + M64(x[8], y7_19) + M64(x[9], y6_19);      \
        _t[6] = M64(x[0], y[6]) + M64(x1_2, y[5]) + M64(x[2], y[4]) + M64(x3_2, y[3]) + M64(x[4], y[2]) + M64(x5_2, y[1]) + M64(x[6], y[0]) + M64(x7_2, y9_19) + M64(x[8], y8_19) + M64(x9_2, y7_19);       \
        _t[7] = M64(x[0], y[7]) + M64(x[1], y[6]) + M64(x[2], y[5]) + M64(x[3], y[4]) + M64(x[4], y[3]) + M64(x[5], y[2]) + M64(x[6], y[1]) + M64(x[7], y[0]) + M64(x[8], y9_19) + M64(x[9], y8_19);        \
        _t[8] = M64(x[0], y[8]) + M64(x1_2, y[7]) + M64(x[2], y[6]) + M64(x3_2, y[5]) + M64(x[4], y[4]) + M64(x5_2, y[3]) + M64(x[6], y[2]) + M64(x7_2, y[1]) + M64(x[8], y[0]) + M64(x9_2, y9_19);         \
        _t[9] = M64(x[0], y[9]) + M64(x[1], y[8]) + M64(x[2], y[7]) + M64(x[3], y[6]) + M64(x[4], y[5]) + M64(x[5], y[4]) + M64(x[6], y[3]) + M64(x[7], y[2]) + M64(x[8], y[1]) + M64(x[9], y[0]);          \
                                                                                                                                                                                                            \
        fe_reduce((z), _t);                                                                                                                                                                                 \
    } while (0)
/* fe_mul(z, x, y): public entry point; parenthesises its arguments for fe_mul_. */
#define fe_mul(z, x, y) fe_mul_((z), (x), (y))

/*
 * fe_square_no_reduce(z, x): z = x^2 as ten 64-bit coefficients, with no
 * carry propagation at all.
 *
 * The coefficient formulas are the same as in fe_square_and_reduce:
 * - *_2 operands and trailing `*2` groups supply the factors of 2 for cross
 *   terms and for products of two odd-indexed limbs;
 * - *_19 operands fold index sums >= 10 back past bit 255.
 *
 * The caller must reduce the result; fe_square2 doubles it and then calls
 * fe_reduce. z is a wide array and x a narrow one, so they cannot be the
 * same object here.
 */
DECLSPEC inline void fe_square_no_reduce(FieldElementWide z, const FieldElement x)
{
    // Local 32x32 -> 64-bit multiply helper; #undef'd at the end.
#   define m(a, b) ((uint64_t)(a) * (uint64_t)(b))

    // Doubled limbs (cross terms) and limbs pre-multiplied by 19 (wrap-around).
    uint32_t x0_2   =  2 * x[0];
    uint32_t x1_2   =  2 * x[1];
    uint32_t x2_2   =  2 * x[2];
    uint32_t x3_2   =  2 * x[3];
    uint32_t x4_2   =  2 * x[4];
    uint32_t x5_2   =  2 * x[5];
    uint32_t x6_2   =  2 * x[6];
    uint32_t x7_2   =  2 * x[7];
    uint32_t x5_19  = 19 * x[5];
    uint32_t x6_19  = 19 * x[6];
    uint32_t x7_19  = 19 * x[7];
    uint32_t x8_19  = 19 * x[8];
    uint32_t x9_19  = 19 * x[9];


    // One uncarried coefficient per limb position.
    z[0] = m(x[0],x[0]) + m(x2_2,x8_19) + m(x4_2,x6_19) + (m(x1_2,x9_19) + m(x3_2,x7_19) + m(x[5],x5_19))*2;
    z[1] = m(x0_2,x[1]) + m(x3_2,x8_19) + m(x5_2,x6_19) + (m(x[2],x9_19) + m(x[4],x7_19))*2;
    z[2] = m(x0_2,x[2]) + m(x1_2,x[1]) + m(x4_2,x8_19) + m(x[6],x6_19) + (m(x3_2,x9_19) + m(x5_2,x7_19))*2;
    z[3] = m(x0_2,x[3]) + m(x1_2,x[2]) + m(x5_2,x8_19) + (m(x[4],x9_19) + m(x[6],x7_19))*2;
    z[4] = m(x0_2,x[4]) + m(x1_2,x3_2) + m(x[2],x[2]) + m(x6_2,x8_19) + (m(x5_2,x9_19) + m(x[7],x7_19))*2;
    z[5] = m(x0_2,x[5]) + m(x1_2,x[4]) + m(x2_2,x[3]) + m(x7_2,x8_19) + m(x[6],x9_19)*2;
    z[6] = m(x0_2,x[6]) + m(x1_2,x5_2) + m(x2_2,x[4]) + m(x3_2,x[3]) + m(x[8],x8_19) + m(x7_2,x9_19)*2;
    z[7] = m(x0_2,x[7]) + m(x1_2,x[6]) + m(x2_2,x[5]) + m(x3_2,x[4]) + m(x[8],x9_19)*2;
    z[8] = m(x0_2,x[8]) + m(x1_2,x7_2) + m(x2_2,x[6]) + m(x3_2,x5_2) + m(x[4],x[4]) + m(x[9],x9_19)*2;
    z[9] = m(x0_2,x[9]) + m(x1_2,x[8]) + m(x2_2,x[7]) + m(x3_2,x[6]) + m(x4_2,x[5]) ;

#   undef m
}

DECLSPEC inline void fe_square(FieldElement z, const FieldElement x)
{
    /*
     * z = x^2 (carried). The square is built in a scratch element and copied
     * out afterwards, because fe_square_and_reduce stores into its output
     * while still reading its input. Callers such as fe_invert rely on
     * in-place use like fe_square(t, t).
     */
    FieldElement sq;

    fe_square_and_reduce(sq, x);
    for (int k = 0; k < 10; ++k) {
        z[k] = sq[k];
    }
}

DECLSPEC inline void fe_square2(FieldElement z, const FieldElement x)
{
    /*
     * z = 2 * x^2. The square is kept unreduced in 64-bit limbs, every
     * coefficient is doubled, and only then is the result carried/reduced.
     */
    FieldElementWide wide;

    fe_square_no_reduce(wide, x);
    for (unsigned k = 0; k < 10; ++k) {
        wide[k] <<= 1;
    }
    fe_reduce(z, wide);
}

DECLSPEC inline void fe_invert(FieldElement z, const FieldElement x)
{
    /*
     * z = x^(p - 2) with p = 2^255 - 19. By Fermat's little theorem this is
     * the multiplicative inverse of x (and x == 0 maps to 0).
     *
     * The fixed addition chain uses 254 squarings and 11 multiplications.
     * The trailing comments track the exponent held in each temporary.
     */
    FieldElement t0, t1, t2, t3;

    fe_square(t0, x);                       // t0 = x^2
    fe_square(t1, t0);
    fe_square(t1, t1);                      // t1 = x^8
    fe_mul(t1, x, t1);                      // t1 = x^9
    fe_mul(t0, t0, t1);                     // t0 = x^11
    fe_square(t2, t0);                      // t2 = x^22
    fe_mul(t1, t1, t2);                     // t1 = x^(2^5 - 1)

    fe_square(t2, t1);
    for (int n = 0; n < 4; ++n)
        fe_square(t2, t2);                  // t2 = x^(2^10 - 2^5)
    fe_mul(t1, t2, t1);                     // t1 = x^(2^10 - 1)

    fe_square(t2, t1);
    for (int n = 0; n < 9; ++n)
        fe_square(t2, t2);                  // t2 = x^(2^20 - 2^10)
    fe_mul(t2, t2, t1);                     // t2 = x^(2^20 - 1)

    fe_square(t3, t2);
    for (int n = 0; n < 19; ++n)
        fe_square(t3, t3);                  // t3 = x^(2^40 - 2^20)
    fe_mul(t2, t3, t2);                     // t2 = x^(2^40 - 1)

    for (int n = 0; n < 10; ++n)
        fe_square(t2, t2);                  // t2 = x^(2^50 - 2^10)
    fe_mul(t1, t2, t1);                     // t1 = x^(2^50 - 1)

    fe_square(t2, t1);
    for (int n = 0; n < 49; ++n)
        fe_square(t2, t2);                  // t2 = x^(2^100 - 2^50)
    fe_mul(t2, t2, t1);                     // t2 = x^(2^100 - 1)

    fe_square(t3, t2);
    for (int n = 0; n < 99; ++n)
        fe_square(t3, t3);                  // t3 = x^(2^200 - 2^100)
    fe_mul(t2, t3, t2);                     // t2 = x^(2^200 - 1)

    for (int n = 0; n < 50; ++n)
        fe_square(t2, t2);                  // t2 = x^(2^250 - 2^50)
    fe_mul(t1, t2, t1);                     // t1 = x^(2^250 - 1)

    for (int n = 0; n < 5; ++n)
        fe_square(t1, t1);                  // t1 = x^(2^255 - 2^5)
    fe_mul(z, t1, t0);                      // z  = x^(2^255 - 21) = x^(p - 2)
}

DECLSPEC inline void fe_to_bytes_no_reduce(uint8_t s[32], const FieldElement x)
{
    /*
     * Serialise x as 32 little-endian bytes. The limbs are packed at bit
     * offsets 0, 26, 51, 77, 102, 128, 153, 179, 204 and 230.
     *
     * "no_reduce" means fe_reduce is not run here. A final conditional
     * subtraction of p = 2^255 - 19 is still performed:
     *   1. q is the carry out of the top limb of (x + 19);
     *   2. 19*q is added, the carries are propagated, and bit 255 is dropped.
     * NOTE(review): this relies on the limbs already being close to their
     * 26/25-bit widths (e.g. the output of fe_reduce, as fe_to_bytes
     * provides). Confirm for other callers.
     */
    const uint32_t LOW_25_BITS = (1ul << 25) - 1;
    const uint32_t LOW_26_BITS = (1ul << 26) - 1;

    FieldElement h;
    fe_copy(h, x);

    // Ripple (h + 19) through every limb; what falls out of limb 9 is q.
    uint32_t q = 19;
    for (int k = 0; k < 10; ++k)
        q = (h[k] + q) >> ((k & 1) ? 25 : 26);

    // Subtract q*p: add 19*q, carry upward, and drop the bit above limb 9.
    h[0] += 19 * q;
    for (int k = 0; k < 9; ++k) {
        if ((k & 1) == 0) {
            h[k + 1] += h[k] >> 26;
            h[k] &= LOW_26_BITS;
        } else {
            h[k + 1] += h[k] >> 25;
            h[k] &= LOW_25_BITS;
        }
    }
    h[9] &= LOW_25_BITS;

    // Stream the 255 limb bits into bytes, least significant first.
    uint64_t acc = 0;
    int pending = 0;
    int pos = 0;
    for (int k = 0; k < 10; ++k) {
        acc |= (uint64_t)h[k] << pending;
        pending += (k & 1) ? 25 : 26;
        while (pending >= 8) {
            s[pos++] = (uint8_t)acc;
            acc >>= 8;
            pending -= 8;
        }
    }
    s[pos] = (uint8_t)acc;  // pos == 31: the remaining 7 bits
}

DECLSPEC inline void fe_to_bytes(uint8_t s[32], const FieldElement x)
{
    FieldElement h;
    uint64_t t[10];

    for(int i = 0; i < 10; ++i) {
        t[i] = x[i];
    }

    fe_reduce(h, t);
    fe_to_bytes_no_reduce(s, h);
}

DECLSPEC inline void fe_from_bytes(FieldElement x, const uint8_t data[32])
{
    /*
     * Decode 32 little-endian bytes into limbs.
     *
     * - Each byte is loaded exactly once: 4-byte loads for limbs 0 and 5,
     *   3-byte loads for the others.
     * - The left shifts align each load with its limb's bit offset (0, 26,
     *   51, 77, 102, 128, 153, 179, 204, 230).
     * - fe_reduce then carries the bits that overflow a limb's width into
     *   the next limb.
     * - Bit 255 (the top bit of data[31]) is ignored via the 23-bit mask on
     *   the last load.
     *
     * The helper macros are local to this function and are all #undef'd at
     * the end.
     */
#   define U64I(X, N) ((uint64_t)((X)[N]))
#   define load3(b) (U64I(b, 0) | (U64I(b, 1) << 8) | (U64I(b, 2) << 16))
#   define load4(b) (U64I(b, 0) | (U64I(b, 1) << 8) | (U64I(b, 2) << 16) | (U64I(b, 3) << 24))

    uint64_t h[10];

    // Bytes 29..31 contribute only 23 bits to h[9]; bit 255 is discarded.
    const uint64_t LOW_23_BITS = (uint64_t)(1ul << 23) - 1;
    h[0] =  load4(data);
    h[1] =  load3(data + 4) << 6;
    h[2] =  load3(data + 7) << 5;
    h[3] =  load3(data + 10) << 3;
    h[4] =  load3(data + 13) << 2;
    h[5] =  load4(data + 16);
    h[6] =  load3(data + 20) << 7;
    h[7] =  load3(data + 23) << 5;
    h[8] =  load3(data + 26) << 4;
    h[9] = (load3(data + 29) & LOW_23_BITS) << 2;

    fe_reduce(x, h);

#   undef load3
#   undef load4
#   undef U64I
}